import glfw
import time
from OpenGL.GL import *
from OpenGL.GL.shaders import compileProgram, compileShader

# -----------------------------
# Compute Shader (simple kernel)
# -----------------------------
# GLSL 4.30 compute kernel source; HDGLExecutor.__init__ compiles it as a
# GL_COMPUTE_SHADER.
#
# Work groups are 256 invocations wide along x. Each invocation writes one
# float into the shader storage buffer bound at binding point 0:
#     values[gl_GlobalInvocationID.x] =
#         float((gl_GlobalInvocationID.x + start_index) % 1024) * 0.001
# `start_index` is a uint uniform, so the addition is done in unsigned 32-bit
# arithmetic inside the shader.
#
# NOTE: the kernel performs no bounds check on `values`; the caller has to
# make sure the dispatch does not cover more invocations than the bound
# buffer holds floats.
compute_shader_source = """
#version 430
layout(local_size_x = 256) in;

layout(std430, binding = 0) buffer Data {
    float values[];
};

uniform uint start_index;

void main() {
    uint gid = gl_GlobalInvocationID.x + start_index;
    values[gl_GlobalInvocationID.x] = float(gid % 1024u) * 0.001;
}
"""

# -----------------------------
# HDGL Executor
# -----------------------------
class HDGLExecutor:
    """Run the compute kernel over a set of GPU shader-storage-buffer tiles.

    Construction creates a hidden 1x1 GLFW window to obtain an OpenGL context,
    compiles ``compute_shader_source``, probes the largest SSBO size that can
    be allocated and bound, and splits ``target_vram_bytes`` into equally
    sized tiles that never exceed the probed size.

    The executor can be used as a context manager; ``close()`` releases the
    buffers, the program and the window.

    Attributes:
        window: hidden GLFW window that owns the GL context (``None`` once
            closed).
        shader: linked compute program (``None`` once closed).
        vector_size: bytes accounted per "vector" (16, i.e. a vec4 of floats).
        max_tile_bytes: largest SSBO size found by the probe.
        tile_bytes: size in bytes of every allocated tile.
        buffers: list of SSBO names, one per tile.
        active_vectors: number of vectors that fit in all tiles together.
    """

    # Must match ``local_size_x`` in the compute shader.
    _WORKGROUP_SIZE = 256
    # The kernel writes one 4-byte float per invocation.
    _FLOAT_BYTES = 4
    # ``start_index`` is a GLSL ``uint``; larger values are wrapped to 32 bits.
    _UINT32_MASK = 0xFFFFFFFF

    def __init__(self, target_vram_bytes=7_500_000_000):
        """Create the GL context, compile the kernel and allocate the tiles.

        Args:
            target_vram_bytes: total number of bytes to allocate across all
                SSBO tiles.

        Raises:
            ValueError: if ``target_vram_bytes`` is smaller than one work
                group worth of floats (the kernel would write past the end of
                the buffer).
            RuntimeError: if GLFW cannot be initialised, the window cannot be
                created, no probe size can be allocated, or a tile allocation
                reports a GL error.
        """
        # Validate before touching GLFW so a bad argument leaves no state behind.
        min_bytes = self._WORKGROUP_SIZE * self._FLOAT_BYTES
        if target_vram_bytes < min_bytes:
            raise ValueError(
                f"target_vram_bytes must be at least {min_bytes} bytes, "
                f"got {target_vram_bytes}"
            )

        # Set every attribute close() reads up front, so cleanup works no
        # matter where construction fails.
        self.window = None
        self.shader = None
        self.buffers = []
        self.vector_size = 16  # vec4 float

        if not glfw.init():
            raise RuntimeError("GLFW init failed")

        glfw.window_hint(glfw.VISIBLE, glfw.FALSE)
        self.window = glfw.create_window(1, 1, "hidden", None, None)
        if not self.window:
            self.window = None
            raise RuntimeError("GLFW window creation failed")

        try:
            glfw.make_context_current(self.window)

            # Compile compute shader
            self.shader = compileProgram(
                compileShader(compute_shader_source, GL_COMPUTE_SHADER)
            )

            version = glGetString(GL_VERSION).decode()
            renderer = glGetString(GL_RENDERER).decode()
            print("OpenGL version:", version)
            print("GPU renderer:", renderer)

            # Probe max SSBO size that binds cleanly
            self.max_tile_bytes = self._probe_max_tile_bytes()
            print(f"✅ Max SSBO bindable size: {self.max_tile_bytes/1e6:.1f} MB")

            # Split target VRAM into multiple tiles, none larger than the probe
            num_tiles, self.tile_bytes = self._plan_tiles(
                target_vram_bytes, self.max_tile_bytes, self.vector_size
            )
            for i in range(num_tiles):
                ssbo = glGenBuffers(1)
                # Track the name before allocating so close() frees it even
                # when the allocation below fails.
                self.buffers.append(ssbo)
                glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo)
                glBufferData(
                    GL_SHADER_STORAGE_BUFFER, self.tile_bytes, None, GL_DYNAMIC_DRAW
                )
                # Only reached with an error pending when PyOpenGL's automatic
                # error checking is disabled; otherwise the call above raises.
                err = glGetError()
                if err != GL_NO_ERROR:
                    raise RuntimeError(
                        f"SSBO allocation failed on tile {i}, GL error {err}"
                    )

            self.active_vectors = (self.tile_bytes // self.vector_size) * num_tiles
            print(f"✅ Allocated {len(self.buffers)} SSBO tiles "
                  f"({self.tile_bytes/1e6:.1f} MB each, total ~{target_vram_bytes/1e9:.2f} GB)")
        except BaseException:
            # Do not leak the window or partially allocated tiles.
            self.close()
            raise

    def __enter__(self):
        """Return the executor itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Release GPU resources on leaving the ``with`` block."""
        self.close()
        return False

    def close(self):
        """Delete the SSBO tiles and the program, then destroy the window.

        Safe to call more than once. GLFW itself is left initialised because
        other windows in the process may still depend on it.
        """
        if self.window is None:
            return
        # GL objects can only be deleted with their context current.
        glfw.make_context_current(self.window)
        if self.buffers:
            glDeleteBuffers(len(self.buffers), list(self.buffers))
            self.buffers = []
        if self.shader is not None:
            glDeleteProgram(self.shader)
            self.shader = None
        glfw.destroy_window(self.window)
        self.window = None

    @staticmethod
    def _plan_tiles(target_vram_bytes, max_tile_bytes, vector_size):
        """Return ``(num_tiles, tile_bytes)`` for splitting the target size.

        ``num_tiles`` is the ceiling of ``target / max_tile`` (at least 1) so
        that ``tile_bytes`` never exceeds ``max_tile_bytes``; ``tile_bytes``
        is rounded down to a whole number of vectors.
        """
        num_tiles = max(1, -(-target_vram_bytes // max_tile_bytes))
        tile_bytes = target_vram_bytes // num_tiles
        tile_bytes -= tile_bytes % vector_size
        return num_tiles, tile_bytes

    def _probe_max_tile_bytes(self):
        """Find the largest SSBO size the driver will bind without OOM.

        Tries a fixed list of sizes from 2 GB down to 256 MB on a scratch
        buffer and returns the first that allocates and binds without a GL
        error. The scratch buffer is deleted before returning.

        Raises:
            RuntimeError: if none of the candidate sizes can be allocated.
        """
        # Local import: with PyOpenGL's default error checking a failed GL
        # call raises GLError instead of leaving a flag for glGetError().
        from OpenGL.error import GLError

        test_sizes = (
            2*1024**3,  # 2 GB
            1*1024**3,  # 1 GB
            768*1024**2,
            512*1024**2,
            256*1024**2,
        )
        ssbo = glGenBuffers(1)
        try:
            for size in test_sizes:
                try:
                    glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo)
                    glBufferData(GL_SHADER_STORAGE_BUFFER, size, None, GL_DYNAMIC_DRAW)
                    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo)
                except GLError:
                    # This size failed; fall through to the next smaller one.
                    continue
                # Covers the case where automatic error checking is disabled.
                if glGetError() == GL_NO_ERROR:
                    return size
        finally:
            # Free the probe allocation so it does not eat into the VRAM the
            # real tiles are about to claim.
            glDeleteBuffers(1, [ssbo])
        raise RuntimeError("Could not allocate even 256MB SSBO")

    def process_virtual_lattice(self, virtual_count):
        """Dispatch the kernel over the tiles until ``virtual_count`` vectors
        have been issued, printing progress after every dispatch.

        The tiles are reused round-robin; each dispatch covers at most one
        tile worth of vectors and the last one is clamped to what remains.

        Args:
            virtual_count: total number of virtual vectors to process. A
                value <= 0 issues no dispatch.

        Raises:
            RuntimeError: if work is requested but there are no tiles or a
                tile holds no complete vector (the loop could never advance).
        """
        vectors_per_tile = self.tile_bytes // self.vector_size
        processed = 0
        start_time = time.perf_counter()

        if virtual_count > 0:
            if not self.buffers or vectors_per_tile <= 0:
                raise RuntimeError("No usable SSBO tiles to process the lattice")
            # Program and uniform location do not change between dispatches.
            glUseProgram(self.shader)
            start_index_location = glGetUniformLocation(self.shader, "start_index")

        while processed < virtual_count:
            for ssbo in self.buffers:
                glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, ssbo)

                # The uniform is a 32-bit uint; wrap explicitly instead of
                # relying on the binding layer to truncate a huge Python int.
                glUniform1ui(start_index_location, processed & self._UINT32_MASK)

                num_vectors = min(vectors_per_tile, virtual_count - processed)
                num_groups = (
                    num_vectors + self._WORKGROUP_SIZE - 1
                ) // self._WORKGROUP_SIZE
                # NOTE(review): num_groups is not checked against
                # GL_MAX_COMPUTE_WORK_GROUP_COUNT; confirm the limit on the
                # target driver for large tiles.
                glDispatchCompute(num_groups, 1, 1)
                glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT)

                processed += num_vectors
                # NOTE(review): no glFinish() here, so this presumably times
                # command submission rather than GPU completion.
                elapsed = time.perf_counter() - start_time
                throughput = processed / elapsed if elapsed > 0 else 0
                print(f"Processed {processed}/{virtual_count} virtual vectors "
                      f"| {throughput:,.0f} vec/s | elapsed {elapsed:.2f}s")

                if processed >= virtual_count:
                    break

        print("✅ Completed streaming virtual lattice across tiles")

# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    # Full 16M^3 virtual lattice: (2**24)**3 == 2**72 vectors.
    virtual_vectors = 16_777_216 ** 3  # ≈ 4.7e21

    try:
        executor = HDGLExecutor(target_vram_bytes=7_500_000_000)  # ~7.5 GB
        executor.process_virtual_lattice(virtual_vectors)
    finally:
        # Shut GLFW down (destroying the hidden window and its GL context)
        # whether the run finished, failed, or was interrupted.
        glfw.terminate()
